import numpy as np
from srunner.scenariomanager.carla_data_provider import CarlaDataProvider

class Captioner:
    """Build a textual description ("caption") of the ego vehicle's situation.

    After every call to :meth:`update` the newest caption is stored in the
    ``caption`` instance attribute (a ``str``); :meth:`get_caption` returns it.
    """

    # m/s -> km/h. Velocities come from ``get_velocity()`` (m/s in the CARLA
    # API), while captions print km/h next to the km/h speed limit.
    _MS_TO_KMH = 3.6

    def __init__(self,
                 use_llm=False):
        """
        Initialize states of interested entities
        self.ego_vehicle: the ego vehicle actor
        self.road: description of the current road of the ego vehicle
        self.vehicle: description of information of other vehicles
        self.traffic_light: state of the traffic_light of interest
        self.route: description of the next 10 waypoints
        self.control: the current control that ego vehicle is going to apply
        """
        # TODO: maintain a short-term memory?
        self.ego_vehicle = None
        self.road = None
        self.vehicle = None
        self.traffic_light = None
        self.route = None
        self.control = None
        self.caption = ''
        self.trajectory = []
        self.trajectory_length = 20
        self.use_llm = use_llm

    def update(self,
               ego_vehicle=None,
               road=None,
               vehicle=None,
               traffic_light=None,
               route=None):
        """Store the latest observations and regenerate ``self.caption``.

        Args:
            ego_vehicle: ego actor; must provide ``get_velocity()``,
                ``is_at_traffic_light()`` and ``get_traffic_light_state()``.
            road: indexable road record. Index 1 is the possible lane-change
                direction, 5 the speed limit (printed as km/h) and 6 the ego
                waypoint (needs ``lane_id``/``road_id``).
            vehicle: iterable of per-vehicle records (see ``_simple_caption``),
                or ``None`` for no vehicles.
            traffic_light: traffic light of interest (stored only).
            route: upcoming route waypoints (used by ``_rule_based_caption``).
        """
        self.ego_vehicle = ego_vehicle
        self.road = road
        self.vehicle = vehicle
        self.traffic_light = traffic_light
        self.route = route
        # Trajectory recording is currently disabled:
        # self.trajectory.append([vehicle[6], vehicle[5]])
        # Keep at most ``trajectory_length`` entries (oldest dropped first).
        while len(self.trajectory) > self.trajectory_length:
            self.trajectory.pop(0)
        # ``_rule_based_caption`` is an alternative, more verbose captioner.
        self.caption = self._simple_caption()
        if self.use_llm:
            self.caption = self._llm_summarize()

    def get_caption(self):
        """Return the caption produced by the latest :meth:`update` call."""
        return self.caption

    def _llm_summarize(self):
        """Return an LLM summary of ``self.caption``.

        Not implemented yet: returns the current caption unchanged so that
        ``use_llm=True`` does not fail.
        """
        # TODO: call openai api to summarize the caption
        return self.caption

    @classmethod
    def _speed_kmh(cls, velocity):
        """Return the magnitude of ``velocity`` (x/y/z in m/s) in km/h, 2 decimals."""
        magnitude = np.sqrt(velocity.x**2 + velocity.y**2 + velocity.z**2)
        return np.round(magnitude * cls._MS_TO_KMH, 2)

    def _simple_caption(self):
        """Return a compact, line-per-fact description of the current state.

        Each other-vehicle record is indexed as: 0 id, 2 longitudinal offset
        (> 0 means ahead), 3 lateral offset (> 0 means right), 4 velocity
        vector, 7 waypoint (needs ``lane_id``).
        """
        ego = self.ego_vehicle
        others = () if self.vehicle is None else self.vehicle
        parts = [
            "Ego car speed: {} km/h\n".format(self._speed_kmh(ego.get_velocity())),
            "Ego car speed limit: {} km/h\n".format(self.road[5]),
            "Ego car lane id: {} \n".format(abs(self.road[6].lane_id)),
        ]
        if ego.is_at_traffic_light():
            parts.append("Ego car at traffic light: {} \n".format(ego.get_traffic_light_state()))
        for other in others:
            parts.append(
                "Vehicle {} traveling at {} km/h, and is {} meters {} the ego car and {} meters to the {} of the ego car. It is on the lane {}.".format(
                    other[0], self._speed_kmh(other[4]),
                    np.round(abs(other[2]), 2), "ahead of" if other[2] > 0 else "behind",
                    np.round(abs(other[3]), 2), "right" if other[3] > 0 else "left",
                    abs(other[7].lane_id)))
            parts.append("\n")
        parts.append("Possible Lane Change direction: {} \n".format(self.road[1]))
        return "".join(parts)

    def _rule_based_caption(self):
        """Return a first-person description of the road, nearby vehicles and route.

        Route waypoints are indexed as: 0 longitudinal offset, 1 lateral offset
        (> 0 means right), 4 iterable of landmarks (with ``type``/``value``).
        """
        # Describe road
        speed_limit = self.road[5]
        speed_unit = "km/h"
        road_desc = "I am an autonomous vehicle driving on a road with speed upper limit {} {}. ".format(
            speed_limit, speed_unit)
        ego_wp = self.road[6]

        # Describe vehicles
        others = [] if self.vehicle is None else self.vehicle
        vehicle_desc = "There are {} vehicles near me.\n".format(len(others))
        for other in others:
            curr_veh_desc = "Vehicle {} is {} meters {} me and {} meters to the {} of me. ".format(
                other[0],
                np.round(abs(other[2]), 2), "ahead of" if other[2] > 0 else "behind",
                np.round(abs(other[3]), 2), "right" if other[3] > 0 else "left")
            other_wp = other[7]
            if ego_wp.road_id == other_wp.road_id:
                if ego_wp.lane_id == other_wp.lane_id:
                    curr_veh_desc += "The vehicle is on the same lane as me. "
                elif np.sign(ego_wp.lane_id) != np.sign(other_wp.lane_id):
                    curr_veh_desc += "The vehicle is traveling on the opposite direction. "
                else:
                    # Same sign => same travel direction. With OpenDRIVE lane
                    # numbering (right-hand traffic assumed), |lane_id| grows
                    # towards the driver's right on both sides of the road, so
                    # the side is decided by magnitude, not by the signed diff.
                    side = "right" if abs(other_wp.lane_id) > abs(ego_wp.lane_id) else "left"
                    curr_veh_desc += "The vehicle is traveling in the same direction and is {} lanes to the {} of me. ".format(
                        abs(ego_wp.lane_id - other_wp.lane_id), side)
            vehicle_desc += curr_veh_desc + '\n'

        # Describe routes
        route = [] if self.route is None else self.route
        route_parts = ["My plan next is to "]
        for wp in route:
            step = "go to {} meters ahead of me, and {} meters to the {} of me, ".format(
                np.round(abs(wp[0]), 2),
                np.round(abs(wp[1]), 2),
                "right" if wp[1] > 0 else "left")
            for lm in wp[4]:
                # Landmarks of type "274" are treated as speed-limit signs.
                if lm.type == "274":
                    if lm.value > speed_limit and lm.value > 90:
                        step += "speeding up to merge onto a highway, "
                    if lm.value < speed_limit and lm.value < 90:
                        step += "slowing down to exit the highway, "
                    speed_limit = lm.value
                    speed_unit = "km/h"  # lm.unit TODO: check
            step += "which has a speed limit {} {}, then I will ".format(speed_limit, speed_unit)
            route_parts.append(step)
        route_parts.append("plan further.")
        route_desc = "".join(route_parts)

        return road_desc + '\n' + vehicle_desc + '\n' + route_desc + '\n'


def get_speed(velocity):
    """Return the Euclidean magnitude of ``velocity``, rounded to 2 decimals.

    ``velocity`` only needs ``x``, ``y`` and ``z`` attributes; the result is in
    the same units as those components.
    """
    vx, vy, vz = velocity.x, velocity.y, velocity.z
    squared_norm = vx**2 + vy**2 + vz**2
    return np.round(np.sqrt(squared_norm), 2)
